Support Scan with MIT-MOT in JAX backend #1651
    | For completeness, it would be good to benchmark a realistic model with multiple taps and one where only the last state is used. We could do a hidden Markov model with multiple lags. Any idea for a recursive one where only the last state matters, @jessegrabowski? | 
| I think for that we should try a recursive algorithm, like the Newton solver in this PR, or this eigenvalue deflation algorithm | 
| Also, this refactor makes me ponder: do we really need the distinction between nitsot, sitsot, mitsot, and mitmot internally? mitmots are general enough to cover all the other cases. What are we gaining with this distinction? Even sequences can be understood as a mitmot (the flipside of a nitsot: something with input taps but no output taps) | 
| I guess the ability to rewrite into specialized forms/gradients in each case? We don't take advantage of that at all though | 
| I think it was just piecemeal design, and they didn't start over once they figured out they needed mitmot. The whole taps thing (especially with gaps) is also a bit silly. Who needs taps -1 and -3 but not -2? Just have it as a useless input in the inner function then | 
| I added the cycle detection example. For that one, our new implementation isn't hurting direct JAX much: backward is a bit slower, both is a bit faster | 
    | I added a new commit that uses the original approach for JAX for all outputs, and the new one only for mit-mot. This way there is no performance regression when using JAX directly, and it doesn't seem like we hurt the PyTensor approach either. It's just less elegant. I want to try one last tweak to mit-mot: I reckon we can use JAX tracing for the oldest affected tap, and just append the final carry at the end | 
    | Okay, I think I'm done. Users shouldn't be too penalized by using the PyTensor backend, and may even gain big when JAX screws up (as in that value_and_grad for the SEIR model). The current approach should not incur any regression, as Scans without MIT-MOT are defined pretty much the same as before, plus or minus some cosmetic differences. I left the "pure" implementation as an intermediate commit and left a note in the dispatch function header. | 
Pull Request Overview
This PR implements support for MIT-MOT (Multiple Input Taps - Multiple Output Taps) in the JAX backend for PyTensor's Scan operation. The implementation rewrites the scan dispatch to work more like PyTensor's internal API rather than the user-facing API, using a tape-based approach that reads from and writes to specific buffer locations during iteration.
Key changes:
- Refactored JAX scan implementation to handle MIT-MOT by treating them as buffers with read/write operations
- Added new method inner_mitmot_outs_grouped to group MIT-MOT outputs by variable
- Enhanced test coverage with MIT-MOT specific tests and comprehensive benchmarks
Reviewed Changes
Copilot reviewed 4 out of 5 changed files in this pull request and generated 5 comments.
| File | Description | 
|---|---|
| tests/link/jax/test_scan.py | Adds MIT-MOT test case, removes old SEIR test, adds comprehensive benchmarking suite | 
| pytensor/scan/op.py | Adds helper method to group MIT-MOT outputs by variable instead of flattening | 
| pytensor/link/jax/dispatch/scan.py | Complete rewrite of JAX scan dispatch with MIT-MOT support using buffer-based approach | 
| pytensor/compile/io.py | Minor fix to handle None update values in string representation | 
    | Codecov Report
 ❌ Your patch check has failed because the patch coverage (91.66%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files:
@@            Coverage Diff             @@
##             main    #1651      +/-   ##
==========================================
- Coverage   81.58%   81.57%   -0.02%     
==========================================
  Files         240      240              
  Lines       53593    53596       +3     
  Branches     9454     9455       +1     
==========================================
- Hits        43722    43719       -3     
- Misses       7395     7399       +4     
- Partials     2476     2478       +2     
 | 
| Woohoo another scan PR 🥳 I will take a closer look later but would it be possible to add a small docstring to the cycle detection test to explain what it is about, in the same way as the SEIR test? | 
| Maybe @jessegrabowski can volunteer it since it's his model. Also have to add him as co-author, forgot about that | 
| It should be called  It solves the matrix equation  | 
| I'll mention that for our purposes it's testing a case where only the last state of the Scan is needed by the user | 
PyTensor Variables cannot be called `bool` upon
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
    | Updated | 
Using JAX Scan machinery to create MIT-SOT, SIT-SOT, and NIT-SOT buffers for us seems to be more performant than working directly on the pre-allocated buffers and reading/writing at every iteration. There is no machinery to work with MIT-MOT directly (just like in PyTensor user-facing Scan).
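To make the two strategies concrete, here is a minimal sketch (not the code from this PR) that computes the same SIT-SOT-style recurrence both ways: once by letting `jax.lax.scan` stack the per-step outputs, and once by writing into a pre-allocated buffer at every iteration. The function names `cumsum_scan` and `cumsum_buffer` are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def cumsum_scan(xs, init):
    """Let jax.lax.scan build the output buffer from the per-step outputs."""
    init = jnp.asarray(init, dtype=xs.dtype)

    def step(carry, x):
        new = carry + x
        # The second return value is stacked by scan into the output trace
        return new, new

    _, ys = jax.lax.scan(step, init, xs)
    return ys


def cumsum_buffer(xs, init):
    """Pre-allocate the buffer and read/write it explicitly at every step."""
    n = xs.shape[0]
    # Slot 0 holds the initial state, slots 1..n hold the computed states
    buf = jnp.zeros(n + 1, dtype=xs.dtype).at[0].set(init)

    def step(i, buf):
        # Read the previous state (tap -1) and write the new state
        return buf.at[i + 1].set(buf[i] + xs[i])

    buf = jax.lax.fori_loop(0, n, step, buf)
    return buf[1:]


xs = jnp.arange(1.0, 6.0)
print(cumsum_scan(xs, 0.0))
print(cumsum_buffer(xs, 0.0))
```

Both functions should return the same values; which one JAX compiles and differentiates more efficiently is exactly what the benchmarks in this PR are meant to measure.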
Closes #264
Rewrite scan dispatch to look more like PyTensor's internal API and less like the user-facing API. The old approach didn't allow us to implement MIT-MOT because they have no equivalent in the user-facing API of Scan (neither PyTensor's nor JAX's). They show up only in autodiff.
The new scan implementation works with the concept of tapes and reading/writing to specific locations as it loops, which is how it looks internally in PyTensor. This makes it straightforward to implement MIT-MOTs as it's just a case of reading multiple places (like MIT-SOT), but updating multiple places (unlike MIT-SOT/SIT-SOT/NIT-SOT).
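As a rough illustration of the tape idea (a toy sketch, not the dispatch code in this PR), the loop below treats one array as a tape, reads two positions relative to the current step, and writes two positions back in the same step. The update rule and the name `mitmot_like_loop` are hypothetical; only the read-many/write-many pattern is the point.

```python
import jax
import jax.numpy as jnp


def mitmot_like_loop(tape, n_steps):
    """Read two taps and update two taps of the same tape at every step."""

    def step(i, tape):
        # Input taps: two locations relative to the current step
        a = tape[i]
        b = tape[i + 1]
        # Output taps: two locations are overwritten in the same step,
        # unlike MIT-SOT/SIT-SOT/NIT-SOT, which write a single location
        tape = tape.at[i].set(a + b)
        tape = tape.at[i + 1].set(b - a)
        return tape

    return jax.lax.fori_loop(0, n_steps, step, tape)


tape = jnp.array([1.0, 2.0, 3.0, 4.0])
print(mitmot_like_loop(tape, 3))  # [3. 4. 6. 2.]
```

Because later steps read locations that earlier steps have already overwritten, the per-step outputs cannot simply be stacked the way `jax.lax.scan` stacks its `ys`, so the tape has to be carried through the loop.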
Note: We may not want to use this representation for Scans without MIT-MOT. The old approach may be better optimized/auto-diffed by JAX, since it mimics how users would use their API. Benchmarks needed!
Updated: Hybrid is indeed better
Benchmark results
We're hurting the pure JAX forward (2x) and backward (10x), but we do value_and_grad better than JAX (3-4x) [just to trigger @aseyboldt I guess]
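For readers unfamiliar with the three modes being compared, here is a minimal sketch of what "forward", "backward" and "value_and_grad" mean on the JAX side. The toy model (`loss`, geometric growth where only the last state is kept) is a hypothetical stand-in, not one of the benchmarked models, and no timings are implied.

```python
import jax
import jax.numpy as jnp

N_STEPS = 5


def loss(rate):
    """Toy recurrence state_{t+1} = state_t * rate, keeping only the last state."""

    def step(state, _):
        return state * rate, None

    init = jnp.ones_like(rate)
    final, _ = jax.lax.scan(step, init, None, length=N_STEPS)
    return final


forward = jax.jit(loss)                   # value only
backward = jax.jit(jax.grad(loss))        # gradient only
both = jax.jit(jax.value_and_grad(loss))  # value and gradient in one call

value, grad = both(2.0)
print(value, grad)  # rate**5 = 32.0 and 5 * rate**4 = 80.0
```

A benchmark would time each of the three compiled functions separately (calling `.block_until_ready()` on the results), which is why the forward, backward and combined numbers can move in different directions.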